/*
 * Copyright (C) 2021 Samsung Electronics Co. LTD
 *
 * This software is a property of Samsung Electronics.
 * No part of this software, either material or conceptual may be copied or distributed, transmitted,
 * transcribed, stored in a retrieval system, or translated into any human or computer language in any form by any
 * means, electronic, mechanical, manual or otherwise, or disclosed to third parties without the express written
 * permission of Samsung Electronics. (Use of the Software is restricted to non-commercial, personal or academic,
 * research purpose only)
 */

#pragma once

#include <cassert>
#include <experimental/filesystem>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/script.h>

#include "utils/utils.h"

namespace fs = std::experimental::filesystem;

namespace nn_compiler
{
namespace runtime
{
namespace utils
{
// Horizontal rule written to the debug log between tensor dumps (used by checkTensorEqual()).
// NOTE: being a macro, it is not scoped by the enclosing namespace. The name presumably means
// "SPLIT_LINE"; it is kept as-is because other code may reference it.
#define SPLITE_LINE "================================================"

/**
 * @brief Build a one-line, human-readable summary of a tensor's metadata (element values are not included).
 *
 * @param tensor the tensor to describe
 * @return a string of the form " Shape:<sizes> Dim:<dim> Dtype:<dtype> Numel:<numel> Device:<device>"
 */
inline std::string showTensorInfo(const torch::Tensor& tensor)
{
    // `inline` instead of `static`: a `static` function defined in a header is duplicated in every
    // translation unit and triggers -Wunused-function in the ones that never call it.
    std::ostringstream ss;
    ss << " Shape:" << tensor.sizes() << " Dim:" << tensor.dim() << " Dtype:" << tensor.dtype()
       << " Numel:" << tensor.numel() << " Device:" << tensor.device();
    return ss.str();
}

/**
 * @brief Compare two tensors with Tensor::equal(), logging their metadata (and optionally values) first.
 *
 * Tensors whose dtype, device or shape differ are reported as unequal, with a debug log line saying
 * why. They are not handed to Tensor::equal(), which raises an error (instead of returning false)
 * when the dtypes or devices differ.
 *
 * @param tensor1 first tensor
 * @param tensor2 second tensor
 * @param print_tensor if true, also dump the full values of both tensors to the debug log
 * @return true only if both tensors have the same dtype, device, shape and element values
 */
inline bool checkTensorEqual(const torch::Tensor& tensor1, const torch::Tensor& tensor2, bool print_tensor = false)
{
    DLOG(INFO) << "tensor1:" << showTensorInfo(tensor1);
    DLOG(INFO) << "tensor2:" << showTensorInfo(tensor2);
    if (print_tensor) {
        DLOG(INFO) << SPLITE_LINE;
        DLOG(INFO) << tensor1;
        DLOG(INFO) << SPLITE_LINE;
        DLOG(INFO) << tensor2;
    }

    // Pre-checks turn what would be a thrown error inside Tensor::equal() into a plain "not equal".
    if (tensor1.scalar_type() != tensor2.scalar_type()) {
        DLOG(INFO) << "checkTensorEqual: dtype mismatch";
        return false;
    }
    if (tensor1.device() != tensor2.device()) {
        DLOG(INFO) << "checkTensorEqual: device mismatch";
        return false;
    }
    if (tensor1.sizes() != tensor2.sizes()) {
        DLOG(INFO) << "checkTensorEqual: shape mismatch";
        return false;
    }
    return tensor1.equal(tensor2);
}

class TVComparator
{
   public:
    ~TVComparator() { tv_.clear(); }
    TVComparator(const TVComparator&) = delete;
    TVComparator& operator=(const TVComparator&) = delete;
    static TVComparator& getInstance()
    {
        static TVComparator tv_comparator;
        return tv_comparator;
    }

   private:
    TVComparator() {}

   public:
    /**
     * @brief Load Test vector from binary file, the Tensor of test vector are loaded as default CPUTensor. If loaded,
     * the tensor are saved into a unordered_map. binary file (.bin) it is not pytroch (.pt/.pth) file, since .pt file
     * failed to load in libtorch c++ API, parse .pt file may depend on python pickle module
     *
     * @param bin_file the binary file is `generated by tensor.cpu().numpy().tofile()`
     * @param shape the shape of test_vector
     * @param dtype the data type of test_vector
     * @param ans_key_name set a key/tag name for this test_vector
     */
    void loadTV(const std::string& file, const std::vector<int64_t>& shape, DataType dtype,
                const std::string& ans_key_name)
    {
        auto tensor = utils::loadTensor(file, shape, dtype);
        tv_.insert({ans_key_name, tensor});
    }

    void clear() { tv_.clear(); }

    /**
     * @brief Compare 2 tensors using torch.equal(), return true, only if both size & element value equal
     *
     * @param input_tensor the tensor needed to compare, since the ans_tensor is loaded as CPUTensor, so input_tensor
     * need to be CPUTensor also.
     * @param ans_key_name the key_name of answer, ans_tensor = unordered_map[key_name]
     * @param print_tensor if print the values
     * @return true
     * @return false
     */
    bool compare(const torch::Tensor& input_tensor, const std::string& ans_key_name, bool print_tensor = false)
    {
        auto iter = tv_.find(ans_key_name);
        assert(iter != tv_.end() && "Test vector not exist!");
        auto& other_tensor = iter->second;
        bool status = checkTensorEqual(input_tensor, other_tensor, print_tensor);
        if (status) {
            DLOG(INFO) << "TVComparator: Success !";
        } else {
            DLOG(INFO) << "TVComparator: Failed !";
        }
        return status;
    }

   private:
    std::unordered_map<std::string, torch::Tensor> tv_;

   public:
    void loadTVs()
    {
        // For HWR
        std::string base_dir = "handwritten-text-recognition/output_fp16_bs1/";
        std::string weight_dir = "handwritten-text-recognition/weight_fp16/";
        std::string buffer_dir = "handwritten-text-recognition/buffer_fp16/";

        // Conv_1 ~ Conv_5
        this->loadTV(base_dir + "output_conv_l1_1_16_1024_128.bin", {1, 16, 1024, 128}, DataType::FLOAT16,
                     "aten::conv2d_101");

        this->loadTV(base_dir + "output_conv_l2_1_32_512_64.bin", {1, 32, 512, 64}, DataType::FLOAT16,
                     "aten::conv2d_117");

        this->loadTV(base_dir + "output_conv_l3_1_48_256_32.bin", {1, 48, 256, 32}, DataType::FLOAT16,
                     "aten::conv2d_134");

        this->loadTV(base_dir + "output_conv_l4_1_64_128_16.bin", {1, 64, 128, 16}, DataType::FLOAT16,
                     "aten::conv2d_151");

        this->loadTV(base_dir + "output_conv_l5_1_80_128_16.bin", {1, 80, 128, 16}, DataType::FLOAT16,
                     "aten::conv2d_163");

        // LSTM1
        this->loadTV(base_dir + "output_lstm_1_128_512.bin", {1, 128, 512}, DataType::FLOAT16, "aten::lstm1_204");
        // LogSoftmax
        this->loadTV(base_dir + "output_log_softmax_1_128_98.bin", {1, 128, 98}, DataType::FLOAT16,
                     "aten::log_softmax_211");

        // Transpose
        this->loadTV(base_dir + "output_transpose_reshape_1_128_1280.bin", {1, 128, 1280}, DataType::FLOAT16,
                     "aten::reshape_173");
        this->loadTV(base_dir + "output_final_transpose_128_1_98.bin", {128, 1, 98}, DataType::FLOAT16,
                     "aten::transpose_212");

        // Final output
        this->loadTV(base_dir + "output_hwr_y_hat_128_1_98.bin", {128, 1, 98}, DataType::FLOAT16, "output_hwr");

        // BN1 ~ BN5
        this->loadTV(base_dir + "output_bn_l1_1_16_1024_128.bin", {1, 16, 1024, 128}, DataType::FLOAT16,
                     "aten::batch_norm_107");
        this->loadTV(base_dir + "output_bn_l2_1_32_512_64.bin", {1, 32, 512, 64}, DataType::FLOAT16, "");
        this->loadTV(base_dir + "output_bn_l3_1_48_256_32.bin", {1, 48, 256, 32}, DataType::FLOAT16, "");
        this->loadTV(base_dir + "output_bn_l4_1_64_128_16.bin", {1, 64, 128, 16}, DataType::FLOAT16, "");
        this->loadTV(base_dir + "output_bn_l5_1_80_128_16.bin", {1, 80, 128, 16}, DataType::FLOAT16, "");

        // LeakkRelu_1
        this->loadTV(base_dir + "output_leakyRelu_l1_1_16_1024_128.bin", {1, 16, 1024, 128}, DataType::FLOAT16,
                     "aten::leaky_relu_108");
        // MaxPool_1
        this->loadTV(base_dir + "output_pool_l1_1_16_512_64.bin", {1, 16, 512, 64}, DataType::FLOAT16,
                     "aten::max_pool2d_113");
    }
};

// Test-vector checking is compiled out by default. It can be switched on without editing this file by
// defining TV_ENABLE to a non-zero value on the command line (e.g. -DTV_ENABLE=1).
#ifndef TV_ENABLE
#define TV_ENABLE 0
#endif

#if TV_ENABLE
// Each macro expands to a single, fully-qualified expression. Because no local variable is declared,
// they can be used several times in one scope. They also work from any namespace and act as one
// statement after an `if`.

// Load all test vectors (TVComparator::loadTVs()) into the singleton comparator.
#define TV_LOAD_ALL() ::nn_compiler::runtime::utils::TVComparator::getInstance().loadTVs()

// Compare `input_tensor` (copied to CPU first) with the test vector stored under `ans_key_name`.
// Evaluates to the bool returned by TVComparator::compare().
#define TV_COMPARE_TENSOR(input_tensor, ans_key_name, print_tensor)                         \
    ::nn_compiler::runtime::utils::TVComparator::getInstance().compare((input_tensor).cpu(), \
                                                                       (ans_key_name), (print_tensor))

#else
// Disabled: both macros expand to nothing, and their arguments are not evaluated.
#define TV_LOAD_ALL()
#define TV_COMPARE_TENSOR(input_tensor, ans_key_name, print_tensor)
#endif  // TV_ENABLE

}  // namespace utils
}  // namespace runtime
}  // namespace nn_compiler
